{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simultaneous phase and amplitude aberration sensing with a vector-Zernike wavefront sensor\n",
    "\n",
    "We will introduce the classical Zernike wavefront sensor (ZWFS) and a way to reconstruct phase aberrations. Then we will introduce the vector-Zernike WFS (vZWFS) and show how this version allows for simultaneous phase and amplitude aberration sensing. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hcipy import *\n",
    "import numpy as np\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "import os\n",
    "\n",
    "# For notebook animations\n",
    "from matplotlib import animation\n",
    "from IPython.display import HTML"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pupil_grid = make_pupil_grid(256, 1.5)\n",
    "\n",
    "aperture = make_magellan_aperture(True)\n",
    "\n",
    "telescope_pupil = aperture(pupil_grid)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The classical Zernike wavefront sensor is implemented in hcipy. A ZWFS is a focal plane optic, but in HCIPy it is implemented as a pupil plane to pupil plane propagation, similar to the vortex coronagraph. This ensures optimal calculation speed using matrix Fourier transforms (MFT). \n",
    "\n",
    "First, we create the ZWFS optical element. In principle, the only parameter you have to give is the pupil grid. Other parameters that influence the performance are:\n",
    "1. The phase step. For an optimal sensitivity, use $\\pi/2$.\n",
    "2. The phase dot diameter. For an optimal sensitivity, use 1.06 $\\lambda/D$. \n",
    "3. num_pix, sets the number of pixels the MFT uses.\n",
    "4. The pupil diameter.\n",
    "5. The reference wavelength.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ZWFS_ideal = ZernikeWavefrontSensorOptics(pupil_grid)\n",
    "ZWFS_non_ideal = ZernikeWavefrontSensorOptics(pupil_grid, phase_step=0.45 * np.pi, phase_dot_diameter=1.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_ZWFS(wavefront_in, wavefront_out):\n",
    "    '''Plot the input wavefront and ZWFS response.\n",
    "\n",
    "    Parameters\n",
    "    ---------\n",
    "    wavefront_in : Wavefront\n",
    "        The aberrated wavefront coming in\n",
    "    wavefront_out : Wavefront\n",
    "        The wavefront_in propagated through the ZWFS\n",
    "    '''    \n",
    "\n",
    "    # Plotting the phase pattern and the PSF\n",
    "    fig = plt.figure()\n",
    "    ax1 = fig.add_subplot(131)\n",
    "    im1 = imshow_field(wavefront_in.amplitude, cmap='gray')\n",
    "    ax1.set_title('Input amplitude')\n",
    "    \n",
    "    divider = make_axes_locatable(ax1)\n",
    "    cax = divider.append_axes('right', size='5%', pad=0.05)\n",
    "    fig.colorbar(im1, cax=cax, orientation='vertical')\n",
    "\n",
    "    ax2 = fig.add_subplot(132)\n",
    "    im2 = imshow_field(wavefront_in.phase, cmap='RdBu')\n",
    "    ax2.set_title('Input phase')\n",
    "    \n",
    "    divider = make_axes_locatable(ax2)\n",
    "    cax = divider.append_axes('right', size='5%', pad=0.05)\n",
    "    fig.colorbar(im2, cax=cax, orientation='vertical')\n",
    "    \n",
    "    ax3 = fig.add_subplot(133)\n",
    "    im3 = imshow_field(wavefront_out.intensity, cmap='gray')\n",
    "    ax3.set_title('Output intensity')\n",
    "    \n",
    "    divider = make_axes_locatable(ax3)\n",
    "    cax = divider.append_axes('right', size='5%', pad=0.05)\n",
    "    fig.colorbar(im3, cax=cax, orientation='vertical')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now it is time to measure the wavefront aberrations using the ZWFS. We create a random phase aberration with a power law distribution and propagate it through the ZWFS. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Creating the aberrated wavefront\n",
    "phase_aberrated = make_power_law_error(pupil_grid, 0.2, 1)\n",
    "phase_aberrated -= np.mean(phase_aberrated[telescope_pupil >= 0.5])\n",
    "wf = Wavefront(telescope_pupil * np.exp(1j * phase_aberrated))\n",
    "\n",
    "# Applying the ZWFS\n",
    "wf_out = ZWFS_ideal.forward(wf)\n",
    "\n",
    "plot_ZWFS(wf, wf_out)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From the intensity of the outcoming wavefront, it is clear that the intensity is dependent on the input phase. To show a simple reconstruction, we use the reconstruction algorithm of N'Diaye et al. 2013 [1]:\n",
    "\n",
    "\n",
    "$$\\phi = −1 + \\sqrt{2I_c}$$,\n",
    "\n",
    "where $\\phi$ is the phase and $I_c$ the measured intensity. \n",
    "\n",
    "[1] N'Diaye et al. \"Calibration of quasi-static aberrations in exoplanet direct-imaging instruments with a Zernike phase-mask sensor\", Astronomy & Astrophysics 555 (2013)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_reconstruction_phase(phase_in, phase_out, telescope_pupil):\n",
    "    '''Plot the incoming aberrated phase pattern and the reconstructed phase pattern\n",
    "    \n",
    "    Parameters\n",
    "    ---------\n",
    "    phase_in : Field\n",
    "        The phase of the aberrated wavefront coming in\n",
    "    phase_out : Field\n",
    "        The phase of the aberrated wavefront as reconstructed by the ZWFS\n",
    "    '''    \n",
    "    \n",
    "    # Calculating the difference of the reconstructed phase and input phase\n",
    "    diff = phase_out - phase_in\n",
    "    diff -= np.mean(diff[telescope_pupil >= 0.5])\n",
    "\n",
    "    # Plotting the phase pattern and the PSF\n",
    "    fig = plt.figure()\n",
    "    ax1 = fig.add_subplot(131)\n",
    "    im1 = imshow_field(phase_in, cmap='RdBu', vmin=-0.2, vmax=0.2, mask=telescope_pupil)\n",
    "    ax1.set_title('Input phase')\n",
    "    \n",
    "    divider = make_axes_locatable(ax1)\n",
    "    cax = divider.append_axes('right', size='5%', pad=0.05)\n",
    "    fig.colorbar(im1, cax=cax, orientation='vertical')\n",
    "\n",
    "    ax2 = fig.add_subplot(132)\n",
    "    im2 = imshow_field(phase_out, cmap='RdBu', vmin=-0.2, vmax=0.2, mask=telescope_pupil)\n",
    "    ax2.set_title('Reconstructed phase')\n",
    "    \n",
    "    divider = make_axes_locatable(ax2)\n",
    "    cax = divider.append_axes('right', size='5%', pad=0.05)\n",
    "    fig.colorbar(im2, cax=cax, orientation='vertical')\n",
    "    \n",
    "    ax3 = fig.add_subplot(133)\n",
    "    im3 = imshow_field(diff, cmap='RdBu', vmin=-0.02, vmax=0.02, mask=telescope_pupil)\n",
    "    ax3.set_title('Difference')\n",
    "    \n",
    "    divider = make_axes_locatable(ax3)\n",
    "    cax = divider.append_axes('right', size='5%', pad=0.05)\n",
    "    fig.colorbar(im3, cax=cax, orientation='vertical')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "phase_est = -1 + np.sqrt(2 * wf_out.intensity)\n",
    "\n",
    "plot_reconstruction_phase(phase_aberrated, phase_est, telescope_pupil)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using the simplified formula, we reconstruct the phase to a large extend and most of the structure in the phase pattern is reconstructed. However, the difference is larger than 10% of the original phase amplitude, and is structured in the shape similar to defocus. \n",
    "\n",
    "In N'Diaye et al. 2013, they propose a more accurate reconstruction: \n",
    "\n",
    "$$\\phi = \\sqrt{-1 + 3 − 2b −(1−I_c)/b},$$\n",
    "\n",
    "The $b$ is a term correction factor that takes into account the apodization from the Zernike mask. It can be approximated using a Strehl estimate, $S$, the Zernike mask geometry, $M$, and the telescope pupil geometry $P_0$:\n",
    "\n",
    "\n",
    "$$b \\simeq S \\widehat{M} \\otimes P_0 = S b_0.$$\n",
    "\n",
    "Here $\\widehat{M}$ is the Fourier transform of the Zernike mask geometry, $M$.\n",
    "\n",
    "In advance, the Strehl ratio is unknown. We set it to 1. Then we calculate the mask.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "S = 1\n",
    "\n",
    "# Creating the mask M\n",
    "focal_grid = make_focal_grid(20, 2)\n",
    "\n",
    "M = Apodizer(circular_aperture(1.06)(focal_grid))\n",
    "\n",
    "# Calculating b0\n",
    "prop = FraunhoferPropagator(pupil_grid, focal_grid)\n",
    "b0 = prop.backward(M(prop.forward(Wavefront(telescope_pupil)))).electric_field.real\n",
    "\n",
    "# Calculating b\n",
    "b = np.sqrt(S) * b0\n",
    "\n",
    "# Estimating the phase using the equations from above \n",
    "phase_est = -1 + np.sqrt(np.abs(3 - 2 * b - (1 - wf_out.intensity) / b))\n",
    "\n",
    "plot_reconstruction_phase(phase_aberrated, phase_est, telescope_pupil)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is clear that the ZWFS can accurately reconstruct the phase using the correction factor. \n",
    "\n",
    "However, one important limiting factor of the ZWFS reconstruction are amplitude errors, to which it is blind. This means that it reconstructs them as phase errors. To show this effect, we will propagate the phase aberration over a small distance using Fresnel propagation. \n",
    "\n",
    "In general, a ZWFS is least sensitive to tip-tilt modes and residual tip-tilt aberrations will dominate the reconstructed phase. In this notebook we will take the liberty to remove these tip-tilt modes. We also remove piston, as it is not important to reconstruct absolute phase. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Setting up the Fresnel propagator\n",
    "prop_extra = FresnelPropagator(pupil_grid, distance=1e-4)\n",
    "\n",
    "# Creating the power law error\n",
    "phase_aberrated = make_power_law_error(pupil_grid, 0.3, 1)\n",
    "phase_aberrated -= np.mean(phase_aberrated[telescope_pupil >= 0.5])\n",
    "\n",
    "# Removing the piston and tip-tilt modes.\n",
    "zbasis = make_zernike_basis(3, 1, pupil_grid)\n",
    "for test in zbasis:\n",
    "    test*= telescope_pupil\n",
    "    phase_aberrated -= test * np.dot(phase_aberrated, test) / np.dot(test, test)\n",
    "\n",
    "# Use super-Gaussian to avoid edge effects\n",
    "p = telescope_pupil * np.exp(-(pupil_grid.as_('polar').r / 0.68)**20)\n",
    "wf_new = prop_extra(Wavefront(p * np.exp(1j * phase_aberrated)))\n",
    "\n",
    "# Updating reference phase and multiplying with the telescope pupil \n",
    "phase_aberrated = (wf_new.phase - np.pi) % (2 * np.pi) - np.pi\n",
    "wf_new.electric_field[telescope_pupil < 0.5] = 0\n",
    "\n",
    "# Plotting the aberration\n",
    "plt.figure()\n",
    "plt.subplot(1,2,1)\n",
    "imshow_field((wf_new.amplitude * telescope_pupil) - 1, vmin=-0.02, vmax=0.02, cmap='gray')\n",
    "plt.title('Amplitude')\n",
    "plt.subplot(1,2,2)\n",
    "imshow_field(wf_new.phase*telescope_pupil, vmin=-0.2, vmax=0.2, cmap='RdBu', mask=telescope_pupil)\n",
    "plt.title('Phase')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Fresnel propagation added mostly high-frequency ampltitude variations and changed the phase a bit. To see the effect of these aberrations on the ZWFS, we redo the reconstruction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wf_out = ZWFS_ideal.forward(wf_new.copy())\n",
    "\n",
    "S = 1\n",
    "\n",
    "b = np.sqrt(S) * b0\n",
    "\n",
    "phase_est = -1 + np.sqrt(np.abs(3 - 2 * b0 - (1 - wf_out.intensity) / b0))\n",
    "\n",
    "plot_reconstruction_phase(phase_aberrated, phase_est, telescope_pupil)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As expected, the amplitude aberrations have been reconstructed as phase aberrations. The performance is worse and a lot of high-frequency noise has been added to the reconstructed phase. \n",
    "\n",
    "The rest of this notebook will introduce the vector-Zernike wavefront sensor, how the VZWFS output can be analysed and demonstrate the simultaneous phase and amplitude reconstruction. \n",
    "\n",
    "We start by creating the vZWFS optical element, similar to the ZWFS."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vZWFS_ideal = VectorZernikeWavefrontSensorOptics(pupil_grid, num_pix=128)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To demonstrate the reconstructed amplitude, we generate the same plot function for amplitude as has been created for the phase.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_reconstruction_amplitude(amplitude_in, amplitude_out, telescope_pupil):\n",
    "    '''Plot the incoming aberrated amplitude pattern and the reconstructed amplitude pattern\n",
    "    \n",
    "    Parameters\n",
    "    ---------\n",
    "    amplitude_in : Field\n",
    "        The phase of the aberrated wavefront coming in\n",
    "    amplitude_out : Field\n",
    "        The amplitude of the aberrated wavefront as reconstructed by the vZWFS\n",
    "    '''    \n",
    "\n",
    "    amplitude_in = amplitude_in - 1\n",
    "    amplitude_out = amplitude_out - 1\n",
    "    \n",
    "    # Plotting the phase pattern and the PSF\n",
    "    fig = plt.figure()\n",
    "    ax1 = fig.add_subplot(131)\n",
    "    im1 = imshow_field(amplitude_in, cmap='gray', vmin=-0.05, vmax=0.05, mask=telescope_pupil)\n",
    "    ax1.set_title('Input amplitude')\n",
    "    \n",
    "    divider = make_axes_locatable(ax1)\n",
    "    cax = divider.append_axes('right', size='5%', pad=0.05)\n",
    "    fig.colorbar(im1, cax=cax, orientation='vertical')\n",
    "\n",
    "    ax2 = fig.add_subplot(132)\n",
    "    im2 = imshow_field(amplitude_out, cmap='gray', vmin=-0.05, vmax=0.05, mask=telescope_pupil)\n",
    "    ax2.set_title('Reconstructed amplitude')\n",
    "    \n",
    "    divider = make_axes_locatable(ax2)\n",
    "    cax = divider.append_axes('right', size='5%', pad=0.05)\n",
    "    fig.colorbar(im2, cax=cax, orientation='vertical')\n",
    "    \n",
    "    ax3 = fig.add_subplot(133)\n",
    "    im3 = imshow_field(amplitude_out - amplitude_in, cmap='gray', vmin=-0.01, vmax=0.01, mask=telescope_pupil)\n",
    "    ax3.set_title('Difference')\n",
    "    \n",
    "    divider = make_axes_locatable(ax3)\n",
    "    cax = divider.append_axes('right', size='5%', pad=0.05)\n",
    "    fig.colorbar(im3, cax=cax, orientation='vertical')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now comes the most important part. The vZWFS applies $+\\frac{\\pi}{2}$ and $-\\frac{\\pi}{2}$ to the two orthogonal circular polarization states respectivily. Therefore, we have to add a circular polarizing beam splitter that separates these two polarization states. \n",
    "\n",
    "Then we follow the equations in Doelman et al. 2019 [1] to reconstruct the phase. We use both pupil intensities and combine these to calculate the phase and amplitude (or rather the square root of the intensity) from one snap shot.\n",
    "\n",
    "We start off by assuming a Strehl of 1. However, it is possible to update our estimate of b after calculating the phase and amplitude. We update the wavefront estimate that is used to calculate b and repeat the process. This allows for a slightly more accurate phase reconstruction. \n",
    "\n",
    "[1] Doelman, David S., et al. \"Simultaneous phase and amplitude aberration sensing with a liquid-crystal vector-Zernike phase mask.\" Optics letters 44.1 (2019)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def reconstructor(wavefront):\n",
    "    '''Plot the incoming aberrated phase pattern and the reconstructed phase pattern\n",
    "    \n",
    "    Parameters\n",
    "    ---------\n",
    "    phase_in : Field\n",
    "        The phase of the aberrated wavefront coming in\n",
    "    phase_out : Field\n",
    "        The phase of the aberrated wavefront as reconstructed by the ZWFS\n",
    "    '''\n",
    "    # Separate circular polarization states\n",
    "    CPBS = CircularPolarizingBeamSplitter()\n",
    "    wf_ch1, wf_ch2 = CPBS.forward(wavefront.copy())\n",
    "    I_L = wf_ch1.I \n",
    "    I_R = wf_ch2.I\n",
    "\n",
    "    # Creating masks for reconstruction\n",
    "    M = Apodizer(circular_aperture(1.06)(focal_grid))   \n",
    "    b0 = np.abs(prop.backward(M(prop.forward(Wavefront(telescope_pupil)))).electric_field)\n",
    "    S = 1\n",
    "    b = np.sqrt(S) * b0\n",
    "\n",
    "    # Calculating the phase and amplitude aberrations.\n",
    "    for i in range(2):\n",
    "        amp_est = np.nan_to_num(np.sqrt(I_L + I_R + np.sqrt(np.abs(4 * b**2 * (I_R + I_L) - (I_R - I_L)**2 - 4 * b**4))))\n",
    "        phase_est = np.arcsin(I_L - I_R) / (2 * amp_est * b)\n",
    "        \n",
    "        #Updating b for improved estimate\n",
    "        wf_est = Wavefront(amp_est * telescope_pupil * np.exp(1j * phase_est))\n",
    "        b = prop.backward(M(prop.forward(wf_est))).electric_field.real\n",
    "        \n",
    "    return amp_est, phase_est\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "amp_est, phase_est = reconstructor(vZWFS_ideal(wf_new.copy()))\n",
    "I_est = amp_est**2\n",
    "\n",
    "plot_reconstruction_phase(wf_new.phase, phase_est, telescope_pupil)\n",
    "plot_reconstruction_amplitude(wf_new.I, I_est, telescope_pupil)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both the phase and amplitude are reconstructed without cross-talk of phase and amplitude. This is the great advantage of a vector-Zernike wavefront sensor. "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  },
  "level": "intermediate",
  "thumbnail_figure_index": 6
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
